# Copyright (C) 2020  GreenWaves Technologies, SAS

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# Lookup tables and LUT-based approximations of the sigmoid and tanh functions
# (fixed-point and floating-point variants)

import numpy as np

# Q16 LUT for the sigmoid function, used by sigmoid_lut and tanh_lut.
# Entry i is approximately 65536 * sigmoid(i / 24) for i = 0..255
# (entry 0 = 32768 = 0.5, last entry 65535), so the table spans inputs in
# roughly [0, 10.7). The values are in uint16 range but stored as int32 —
# presumably so the fixed-point arithmetic on them (shifts, differences,
# (1 << 16) - entry) cannot overflow.
sigmoid_table = np.array(
    [32768, 33451, 34133, 34813, 35493, 36169, 36843, 37513, 38180, 38841, 39498,
     40149, 40794, 41432, 42064, 42688, 43304, 43912, 44511, 45102, 45683, 46255,
     46817, 47369, 47911, 48443, 48964, 49475, 49975, 50464, 50942, 51409, 51865,
     52311, 52745, 53169, 53581, 53983, 54374, 54755, 55125, 55485, 55834, 56174,
     56503, 56823, 57133, 57433, 57724, 58007, 58280, 58544, 58800, 59048, 59288,
     59519, 59743, 59959, 60168, 60370, 60565, 60753, 60935, 61110, 61279, 61441,
     61599, 61750, 61896, 62036, 62172, 62302, 62428, 62549, 62666, 62778, 62886,
     62990, 63090, 63186, 63279, 63368, 63454, 63536, 63615, 63691, 63765, 63835,
     63903, 63968, 64030, 64090, 64148, 64204, 64257, 64308, 64357, 64405, 64450,
     64494, 64536, 64576, 64614, 64652, 64687, 64721, 64754, 64786, 64816, 64845,
     64873, 64900, 64926, 64950, 64974, 64997, 65019, 65039, 65060, 65079, 65097,
     65115, 65132, 65149, 65164, 65179, 65194, 65208, 65221, 65234, 65246, 65258,
     65269, 65280, 65291, 65301, 65310, 65319, 65328, 65337, 65345, 65352, 65360,
     65367, 65374, 65381, 65387, 65393, 65399, 65404, 65410, 65415, 65420, 65425,
     65429, 65433, 65438, 65442, 65445, 65449, 65453, 65456, 65459, 65462, 65465,
     65468, 65471, 65474, 65476, 65479, 65481, 65483, 65485, 65488, 65489, 65491,
     65493, 65495, 65497, 65498, 65500, 65501, 65503, 65504, 65505, 65507, 65508,
     65509, 65510, 65511, 65512, 65513, 65514, 65515, 65516, 65517, 65517, 65518,
     65519, 65520, 65520, 65521, 65522, 65522, 65523, 65523, 65524, 65524, 65525,
     65525, 65526, 65526, 65526, 65527, 65527, 65528, 65528, 65528, 65529, 65529,
     65529, 65529, 65530, 65530, 65530, 65530, 65531, 65531, 65531, 65531, 65531,
     65532, 65532, 65532, 65532, 65532, 65532, 65533, 65533, 65533, 65533, 65533,
     65533, 65533, 65533, 65534, 65534, 65534, 65534, 65534, 65534, 65534, 65534,
     65534, 65534, 65535], dtype=np.int32)

# Float64 LUT for the sigmoid function, used by sigmoid_lut_float and
# tanh_lut_float: entry i is sigmoid(i / 32) for i = 0..255, i.e. the table
# covers inputs in [0, 8) with a step of 1/32.
sigmoid_table_float = np.array(
    [0.5000000000000000, 0.5078118642792044, 0.5156199157230156, 0.5234203489363240, 0.5312093733737563,
     0.5389832206876841, 0.5467381519846138, 0.5544704649604273, 0.5621765008857981, 0.5698526514141571,
     0.5774953651858118, 0.5851011542032312, 0.5926665999540697, 0.6001883592602050, 0.6076631698328917,
     0.6150878555160665, 0.6224593312018546, 0.6297746074044134, 0.6370307944803831, 0.6442251064863697,
     0.6513548646660542, 0.6584175005616830, 0.6654105587468140, 0.6723316991792860, 0.6791786991753930,
     0.6859494550081925, 0.6926419831347361, 0.6992544210587585, 0.7057850278370112, 0.7122321842389470,
     0.7185943925708561, 0.7248702761768248, 0.7310585786300049, 0.7371581626286834, 0.7431680086124811,
     0.7490872131147275, 0.7549149868676283, 0.7606506526772884, 0.7662936430859597, 0.7718434978390747,
     0.7772998611746911, 0.7826624789529376, 0.7879311956428947, 0.7931059511841119, 0.7981867777396212,
     0.8031737963569016, 0.8080672135527632, 0.8128673178375735, 0.8175744761936437, 0.8221891305219503,
     0.8267117940706734, 0.8311430478583168, 0.8354835371034369, 0.8397339676722393, 0.8438951025545426,
     0.8479677583778257, 0.8519528019683106, 0.8558511469672594, 0.8596637505099167, 0.8633916099737806,
     0.8670357598021706, 0.8705972684083610, 0.8740772351648677, 0.8774767874818407, 0.8807970779778823,
     0.8840392817460332, 0.8872045937171068, 0.8902942261220291, 0.8933094060543487, 0.8962513731336202,
     0.8991213772699436, 0.9019206765295400, 0.9046505351008906, 0.9073122213606241, 0.9099070060380482,
     0.9124361604769414, 0.9149009549929797, 0.9173026573249612, 0.9196425311777929, 0.9219218348550491,
     0.9241418199787566, 0.9263037302939515, 0.9284088005554476, 0.9304582554941719, 0.9324533088603709,
     0.9343951625409306, 0.9362850057480346, 0.9381240142763573, 0.9399133498259924, 0.9416541593883190,
     0.9433475746920261, 0.9449947117065459, 0.9465966702001757, 0.9481545333502200, 0.9496693674025232,
     0.9511422213778251, 0.9525741268224334, 0.9539660976007597, 0.9553191297273498, 0.9566342012360932,
     0.9579122720843811, 0.9591542840900507, 0.9603611608990300, 0.9615338079816756, 0.9626731126558706,
     0.9637799441350253, 0.9648551535992067, 0.9658995742876885, 0.9669140216112958, 0.9678992932829918,
     0.9688561694652216, 0.9697854129326061, 0.9706877692486436, 0.9715639669551474, 0.9724147177732099,
     0.9732407168145568, 0.9740426428022031, 0.9748211582993981, 0.9755769099458929, 0.9763105287006256,
     0.9770226300899744, 0.9777138144607745, 0.9783846672373524, 0.9790357591818724, 0.9796676466573412,
     0.9802808718926534, 0.9808759632491112, 0.9814534354878828, 0.9820137900379085, 0.9825575152637948,
     0.9830850867332734, 0.9835969674838370, 0.9840936082881853, 0.9845754479181588, 0.9850429134068500,
     0.9854964203086180, 0.9859363729567544, 0.9863631647185701, 0.9867771782476948, 0.9871787857334058,
     0.9875683491468141, 0.9879462204837635, 0.9883127420043056, 0.9886682464686385, 0.9890130573694068,
     0.9893474891602738, 0.9896718474806949, 0.9899864293768268, 0.9902915235185259, 0.9905874104123901,
     0.9908743626108194, 0.9911526449170697, 0.9914225145862880, 0.9916842215225259, 0.9919380084717284,
     0.9921841112107114, 0.9924227587321393, 0.9926541734255229, 0.9928785712542649, 0.9930961619287780,
     0.9933071490757153, 0.9935117304033434, 0.9937100978631075, 0.9939024378074247, 0.9940889311437562,
     0.9942697534850080, 0.9944450752963090, 0.9946150620382221, 0.9947798743064417, 0.9949396679680339,
     0.9950945942942779, 0.9952448000901669, 0.9953904278206259, 0.9955316157335089, 0.9956684979794361,
     0.9958012047285281, 0.9959298622841040, 0.9960545931933981, 0.9961755163553627, 0.9962927471256138,
     0.9964063974185798, 0.9965165758069192, 0.9966233876182606, 0.9967269350293284, 0.9968273171575148,
     0.9969246301499518, 0.9970189672701452, 0.9971104189822263, 0.9971990730328789, 0.9972850145309949,
     0.9973683260251155, 0.9974490875787134, 0.9975273768433653, 0.9976032691298710, 0.9976768374773688,
     0.9977481527204970, 0.9978172835546547, 0.9978842965994068, 0.9979492564600826, 0.9980122257876160,
     0.9980732653366725, 0.9981324340221067, 0.9981897889737974, 0.9982453855899006, 0.9982992775885648,
     0.9983515170581486, 0.9984021545059827, 0.9984512389057120, 0.9984988177432630, 0.9985449370614672,
     0.9985896415033819, 0.9986329743543437, 0.9986749775827868, 0.9987156918798661, 0.9987551566979116,
     0.9987934102877537, 0.9988304897349445, 0.9988664309949126, 0.9989012689270753, 0.9989350373279441,
     0.9989677689632452, 0.9989994955990900, 0.9990302480322174, 0.9990600561193372, 0.9990889488055994,
     0.9991169541522155, 0.9991440993632549, 0.9991704108116409, 0.9991959140643708, 0.9992206339069791,
     0.9992445943672705, 0.9992678187383391, 0.9992903296008995, 0.9993121488449448, 0.9993332976907565,
     0.9993537967092804, 0.9993736658418905, 0.9993929244195585, 0.9994115911814431, 0.9994296842929216,
     0.9994472213630764, 0.9994642194616526, 0.9994806951355050, 0.9994966644245472, 0.9995121428772178,
     0.9995271455654797, 0.9995416870993650, 0.9995557816410785, 0.9995694429186754, 0.9995826842393223,
     0.9995955185021579, 0.9996079582107623, 0.9996200154852480, 0.9996317020739840, 0.9996430293649623,
     0.9996540083968218])

# Lookup mode for sigmoid_lut / tanh_lut: True reads one table entry at the
# (truncated) index; False takes the branch that combines two adjacent
# entries (ua/ub) with extra fixed-point rounding terms.
NEAREST = True


def sigmoid_lut(x, qtype=None, q16_out=False):
    """ Lookup table based sigmoid (fixed-point emulation).

        Input is in Q12
        Output is in Q15, or in Q16 when q16_out is True

        The Q12 input is scaled by 3/4 before indexing: the table index is
        (|x| * 3) >> 9, so entry i of sigmoid_table corresponds to an input
        of i / 24 and the table spans roughly [0, 10.7). Indices >= 255
        saturate. Negative inputs use sigmoid(-t) = 1 - sigmoid(t).

        Args:
            x: integer array (or integer scalar) in Q12.
            qtype: ignored; kept for interface compatibility.
            q16_out: if True, return Q16 instead of Q15.

        Returns:
            Integer array holding sigmoid(x) in Q15 (or Q16).
    """
    del qtype
    # Widen to int64 before "* 3": in the input's own dtype the product can
    # wrap (int16 for |x| >= 10923, int32 for |x| >= 715827883, and
    # abs(INT_MIN) stays negative), which produced negative table indices,
    # i.e. wrong values or an IndexError. same_kind casting still rejects
    # float input, which the integer shifts below never accepted either.
    x64 = np.asarray(x).astype(np.int64, casting="same_kind")
    abs_x = (np.abs(x64) * 3) >> 9  # input in Q12, index = 24 * |x_real|
    # Index 0 is only a placeholder where the lookup saturates anyway.
    abs_x_masked = np.where(abs_x >= 255, 0, abs_x)
    if not NEAREST:
        ua = sigmoid_table[abs_x_masked]
        ub = sigmoid_table[abs_x_masked+1]
        # NOTE(review): where the lookup is used (abs_x <= 254), ut equals
        # abs_x itself rather than a fractional weight; kept as-is to match
        # the existing behaviour — confirm against the target kernel.
        ut = abs_x & 0xFF
        result = np.where(abs_x >= 255,
                          0x7FFF << 10,
                          (ua << 9) + ut * (ub-ua))  # Q16 << 9 = Q25
        # Add a rounding term; the negative side mirrors around 1.0 (Q25).
        result = np.where(x > 0, result + (1 << 9),
                          (1 << (9+16))-result+(1 << 9)-1)
        if q16_out:
            return result >> 9
        return result >> 10
    # Nearest entry in Q16; 0xFFFF stands in for 1.0 past the table.
    result = np.where(abs_x >= 255,
                      0xFFFF,
                      sigmoid_table[abs_x_masked])
    result = np.where(x > 0, result, (1 << 16)-result)
    if q16_out:
        return result
    return result >> 1


def tanh_lut(x, qtype=None):
    """ Lookup table based tanh (fixed-point emulation).

        Input is in Q12
        Output is in Q15

        Uses tanh(t) = 2 * sigmoid(2 * t) - 1 with sigmoid_table (entry i
        corresponds to an input of i / 24). The table index is
        (|x| * 3) >> 8, i.e. 24 * (2 * |t|), so the lookup saturates for
        |t| >= ~5.3 (index >= 255). Negative inputs use tanh(-t) = -tanh(t).

        Args:
            x: integer array (or integer scalar) in Q12.
            qtype: ignored; kept for interface compatibility.

        Returns:
            Integer array holding tanh(x) in Q15.
    """
    del qtype
    # Widen to int64 before "* 3": in the input's own dtype the product can
    # wrap (int16 for |x| >= 10923, int32 for |x| >= 715827883, and
    # abs(INT_MIN) stays negative), which produced negative table indices,
    # i.e. wrong values or an IndexError. same_kind casting still rejects
    # float input, which the integer shifts below never accepted either.
    x64 = np.asarray(x).astype(np.int64, casting="same_kind")
    abs_x = (np.abs(x64) * 3) >> 8  # 2*abs_x compared to sigmoid_lut
    # Index 0 is only a placeholder where the lookup saturates anyway.
    abs_x_masked = np.where(abs_x >= 255, 0, abs_x)
    if not NEAREST:
        ua = sigmoid_table[abs_x_masked]
        ub = sigmoid_table[abs_x_masked+1]
        # NOTE(review): where the lookup is used (abs_x <= 254), ut equals
        # abs_x itself rather than a fractional weight; kept as-is to match
        # the existing behaviour — confirm against the target kernel.
        ut = abs_x & 0xFF
        result = np.where(abs_x >= 255,
                          0xFFFF << 8,
                          (ua << 8) + ut * (ub-ua))  # Q16 << 8 = Q24
        # sigmoid - 0.5 (Q24) equals tanh in Q23; add a rounding term.
        result = np.where(x > 0,
                          result - (1 << (9+14)) + (1 << (9-2)),
                          -result + (1 << (9+14)) + (1 << (9-2)) - 1)
        return result >> 8
    # Nearest entry in Q16; 0xFFFF stands in for 1.0 past the table.
    result = np.where(abs_x >= 255,
                      0xFFFF,
                      sigmoid_table[abs_x_masked])
    # 2 * sigmoid - 1 in Q15 is the Q16 sigmoid value minus 1 << 15.
    return np.where(x > 0, result - (1 << 15), -result + (1 << 15))

def sigmoid_lut_float(x, dtype=None):
    """ Lookup table based sigmoid float version.

        Mirrors the fixed-point lookup: x is rounded to Q12 and
        |x_q12| >> 7 (about 32 * |x|, truncated) indexes sigmoid_table_float,
        whose entry i is sigmoid(i / 32). Indices >= 255 (|x| >= ~7.97)
        saturate to 1.0; negative inputs use sigmoid(-t) = 1 - sigmoid(t).

        Args:
            x: float array or scalar.
            dtype: output dtype; np.float32 when None.

        Returns:
            Array of dtype ``dtype`` holding the sigmoid of x.
    """
    if dtype is None:
        dtype = np.float32
    # Clamp before the int32 cast: for |x| >= 2**19 (or +/-inf) the value
    # x * 2**12 does not fit in int32. The float->int cast is then undefined
    # (typically INT32_MIN, whose abs() stays negative), and the resulting
    # negative index raised IndexError. +/-2**16 (|x| = 16) is already well
    # inside the saturated range, so in-range results are unchanged.
    x_q12 = np.clip(np.round(x * 2**(12)), -(1 << 16), 1 << 16)
    x_int = x_q12.astype(np.int32)
    abs_x = np.abs(x_int) >> 7
    # Index 0 is only a placeholder where the lookup saturates anyway.
    abs_x_masked = np.where(abs_x >= 255, 0, abs_x)
    result = np.where(abs_x >= 255,
                      1.0,
                      sigmoid_table_float[abs_x_masked])
    # cast to actual dtype
    return np.where(x > 0, result.astype(dtype), 1 - result.astype(dtype))


def tanh_lut_float(x, dtype=None):
    """ Lookup table based tanh float version.

        Uses tanh(t) = 2 * sigmoid(2 * t) - 1: x is rounded to Q12 and
        |x_q12| >> 6 (about 64 * |x|, truncated) indexes sigmoid_table_float,
        whose entry i is sigmoid(i / 32). Indices >= 255 (|x| >= ~3.98)
        saturate to +/-1.0; negative inputs use tanh(-t) = -tanh(t).

        Args:
            x: float array or scalar.
            dtype: output dtype; np.float32 when None.

        Returns:
            Array of dtype ``dtype`` holding the tanh of x.
    """
    if dtype is None:
        dtype = np.float32
    # Clamp before the int32 cast: for |x| >= 2**19 (or +/-inf) the value
    # x * 2**12 does not fit in int32. The float->int cast is then undefined
    # (typically INT32_MIN, whose abs() stays negative), and the resulting
    # negative index raised IndexError. +/-2**16 (|x| = 16) is already well
    # inside the saturated range, so in-range results are unchanged.
    x_q12 = np.clip(np.round(x * 2**(12)), -(1 << 16), 1 << 16)
    x_int = x_q12.astype(np.int32)
    abs_x = np.abs(x_int) >> 6
    # Index 0 is only a placeholder where the lookup saturates anyway.
    abs_x_masked = np.where(abs_x >= 255, 0, abs_x)
    result = np.where(abs_x >= 255,
                      1.0,
                      sigmoid_table_float[abs_x_masked])
    # cast to actual dtype
    return np.where(x < 0, - 2*result.astype(dtype) + 1.0, 2*result.astype(dtype) - 1.0)
